{# ----------------------------------------------------------------------------
 # SymForce - Copyright 2022, Skydio, Inc.
 # This source code is under the Apache 2.0 license found in the LICENSE file.
 # ---------------------------------------------------------------------------- #}

{# ------------------------------------------------------------------------- #}
{# Utilities for PyTorch code generation templates.                          #}
{# ------------------------------------------------------------------------- #}

{# Render the Python type annotation emitted for one input or output of a generated function.
 #
 # Mapping implemented below:
 #   - `Symbol`, symbolic values, `float`, and `Matrix` subclasses  ->  torch.Tensor
 #   - `NoneType`                                                    ->  None
 #   - sequences, as inputs   ->  T.Sequence[<type of the first element>]
 #   - sequences, as outputs  ->  T.List[float]
 #   - `DataBuffer`, `Values` subclasses, and anything else are rejected by calling `raise`
 #     with a message naming the offending value.
 #
 # Args:
 #     T_or_value (type or Element): Type, or value whose type is looked up with
 #       typing_util.get_type
 #     name (str): Name of the argument or output; only used in error messages here
 #     is_input (bool): Is this an input argument or return value?
 #     available_classes (T.List[type]):  A list sym classes already available (meaning
 #       they should be referenced by just their name, and not, say, sym.Rot3).  No branch of
 #       this macro consults it; it is forwarded unchanged to the recursive call for sequence
 #       elements so nested calls see the same argument as the outer call.
 #}
{%- macro format_typename(T_or_value, name, is_input, available_classes = []) %}
    {%- set T = typing_util.get_type(T_or_value) -%}
    {%- if T.__name__ == 'DataBuffer' -%}
        {{ raise('PyTorch backend does not support DataBuffers, got DataBuffer for "{}"'.format(name)) }}
    {%- elif T.__name__ == 'Symbol' or is_symbolic(T_or_value) or T.__name__ == "float" -%}
        torch.Tensor
    {%- elif T.__name__ == 'NoneType' -%}
        None
    {%- elif issubclass(T, Matrix) -%}
        torch.Tensor
    {%- elif issubclass(T, Values) -%}
        {{ raise('PyTorch backend does not support Values inputs or outputs, got Values for "{}"'.format(name)) }}
    {%- elif is_sequence(T_or_value) -%}
        {%- if is_input -%}
            {#- The element type is taken from the first element only. -#}
            T.Sequence[{{ format_typename(T_or_value[0], name, is_input, available_classes=available_classes) }}]
        {%- else -%}
            T.List[float]
        {%- endif -%}
    {%- else -%}
        {{ raise('PyTorch backend does not support geo or cam types, got value "{}" of type `{}`'.format(name, T)) }}
    {%- endif -%}
{% endmacro -%}

{# ------------------------------------------------------------------------- #}

{# Render the return type annotation for the function described by spec.
 #
 # With exactly one output, its formatted type is emitted on its own.  With several outputs,
 # a `T.Tuple[...]` of the formatted types is emitted, in the order of spec.outputs.  With no
 # outputs, `None` is emitted.
 #
 # Args:
 #     spec (Codegen):
 #     available_classes (T.List[type]):  A list sym classes already available (meaning
 #       they should be referenced by just their name, and not, say, sym.Rot3).
 #}
{%- macro get_return_type(spec, available_classes = []) %}
    {%- set num_outputs = spec.outputs.keys() | length -%}
    {%- if num_outputs == 1 -%}
        {%- set out_name, out_type = spec.outputs.items() | first -%}
        {{ format_typename(out_type, out_name, is_input=False, available_classes=available_classes) }}
    {%- elif spec.outputs -%}
        T.Tuple[
        {%- for out_name, out_type in spec.outputs.items() -%}
        {{ format_typename(out_type, out_name, is_input=False, available_classes=available_classes) }}
        {{- ", " if not loop.last else "" }}
        {%- endfor -%}]
    {%- else -%}
        None
    {%- endif -%}
{% endmacro -%}

{# ------------------------------------------------------------------------- #}

{# Emit a triple-quoted docstring for a generated function.
 #
 # The given text is split on newlines and emitted one line at a time, followed by a fixed
 # explanation of the expected tensor shapes and of the tensor_kwargs argument.
 #
 # Args:
 #     docstring (str): Text placed at the top of the emitted docstring
 #}
{% macro print_docstring(docstring) %}
"""
{%- for doc_line in docstring.split('\n') %}
{{ doc_line }}
{% endfor %}

The types for inputs and outputs above are the shapes of the symbolic matrices this function was
generated from.  Tensors for both inputs and outputs are expected to have dimensions as follows:

- (..., N, M) for sf.Matrix types of shape (N, M) where N != 1 and M != 1
- (..., N) for sf.Matrix types of shape (N, M) where N != 1 and M == 1 (i.e. column vectors)
- (..., M) for sf.Matrix types of shape (N, M) where N == 1 and M != 1 (i.e. row vectors)
- (...) for sf.Matrix types of shape (1, 1) and sf.Scalar types

Where `...` is an arbitrary number of batch dimensions that can be present on the front of input
tensors.  The outputs will have the same batch dimensions as the inputs.

The tensor_kwargs argument can be used to set the device and dtype to use.  If not provided, it will
default to the device and dtype of the first input tensor, or the empty dict if there are no inputs.
"""
{% endmacro %}

{# ------------------------------------------------------------------------- #}
{# Emit the function name followed by its parenthesized argument list.
 #
 # Arguments are the keys of spec.inputs in order, each followed by ", ", and then a trailing
 # `tensor_kwargs = None` argument.
 #
 # Args:
 #     spec (Codegen):
 #}
{%- macro function_name_and_args(spec) -%}
{{ spec.name }}(
{%- for arg_name in spec.inputs.keys() -%}
{{ arg_name }}{{ ', ' }}
{%- endfor -%}
tensor_kwargs = None)
{%- endmacro -%}

{# ------------------------------------------------------------------------- #}

{# Emit the `def` line of a generated function plus its `# type:` signature comment.
 #
 # When is_method is set and the spec has no input called "self", the definition is preceded
 # by `@staticmethod`.  The type comment lists the formatted type of every input, then
 # `T.Optional[TensorKwargs]` for the trailing tensor_kwargs argument, then the return type.
 #
 # Args:
 #     spec (Codegen):
 #     is_method (bool): True if function is a method of a class,
 #     available_classes (T.List[type]):  A list sym classes already available (meaning
 #       they should be referenced by just their name, and not, say, sym.Rot3).
 #}
{%- macro function_declaration(spec, is_method = False, available_classes = []) -%}
{%- if is_method and not ("self" in spec.inputs) -%}
@staticmethod
{% endif %}
def {{ function_name_and_args(spec) }}:
    # type: (
    {%- for arg_name, arg_type in spec.inputs.items() -%}
    {{ format_typename(arg_type, arg_name, is_input=True, available_classes=available_classes) }}{{ ', ' }}
    {%- endfor -%}
    T.Optional[TensorKwargs]) -> {{ get_return_type(spec, available_classes=available_classes) }}
{%- endmacro -%}

{# ------------------------------------------------------------------------- #}

{# Generate inner code for computing the given expression.
 #
 # The emitted function body consists of, in order:
 #   1. A comment with the total op count from spec.print_code_results.
 #   2. A fallback for when `tensor_kwargs` is None: the device and dtype of the first input,
 #      or an empty dict if the spec has no inputs.
 #   3. A `_<name> = <name>.data` alias for every input that is none of Matrix, Values,
 #      symbolic, sequence, or float.
 #   4. One assignment per intermediate term.
 #   5. One `_<name>` assignment per dense output: Matrix outputs are assembled from their
 #      terms with `_broadcast_and_stack`, scalar outputs take their single term.
 #   6. A `return` of the `_<name>` variables of all outputs, comma-separated.
 #
 # Sparse outputs, and outputs that are neither Matrix nor scalar, are rejected by calling
 # `raise` with an explanatory message.
 #
 # The emitted code refers to `tensor_kwargs` and `_broadcast_and_stack`; neither is defined
 # by this macro, so both must be in scope wherever the output is placed.
 #
 # Args:
 #     spec (Codegen):
 #     available_classes (T.List[type]):  A list sym classes already available (meaning
 #       they should be referenced by just their name, and not, say, sym.Rot3).  Not referenced
 #       in the body of this macro.
 #}
{% macro expr_code(spec, available_classes = []) %}
    # Total ops: {{ spec.print_code_results.total_ops }}

    # Deduce expected tensor device and dtype if not provided
    if tensor_kwargs is None:
    {% if spec.inputs %}
    {% set first_arg = spec.inputs | first %}
        tensor_kwargs = {
            "device": {{ first_arg }}.device,
            "dtype": {{ first_arg }}.dtype
        }
    {% else %}
        tensor_kwargs = {}
    {% endif %}

    # Input arrays
    {% for name, type in spec.inputs.items() %}
        {# Only inputs that are none of Matrix, Values, symbolic, sequence, or float get an alias #}
        {% set T = typing_util.get_type(type) %}
        {% if not issubclass(T, Matrix) and not issubclass(T, Values) and not is_symbolic(type) and not is_sequence(type) and not T.__name__ == "float" %}
    _{{ name }} = {{ name }}.data
        {% endif %}
    {% endfor %}

    # Intermediate terms ({{ spec.print_code_results.intermediate_terms | length }})
    {% for lhs, rhs in spec.print_code_results.intermediate_terms %}
    {{ lhs }} = {{ rhs }}
    {% endfor %}

    # Output terms
    {# Render all non-sparse terms -#}
    {% for name, type, terms in spec.print_code_results.dense_terms %}
        {% set T = typing_util.get_type(type) %}
        {% if issubclass(T, Matrix) %}
            {# A single-column matrix is emitted as one parenthesized entry; otherwise one entry
               per column is emitted and the entries are stacked on dim=-1.  Each entry is a
               single term when rows == 1, else a stack of that column's row terms. #}
            {% set rows = type.shape[0] %}
            {% set cols = type.shape[1] %}
            {% if cols == 1 %}
    _{{ name }} = (
            {% else %}
    _{{ name }} = _broadcast_and_stack([
            {% endif %}
            {% for j in range(cols) %}
                {% if rows == 1 %}
                {{ terms[j][1] }}{% if not loop.last %},{% endif %}
                {% else %}
                _broadcast_and_stack([
                {% for i in range(rows) %}
                    {# NOTE(brad): The order of the terms is the storage order of geo.Matrix. If the
                    storage order of geo.Matrix is changed (i.e., from column major to row major),
                    the index into terms will have to be changed to match that order. #}
                    {{ terms[j * rows + i][1] }}{% if not loop.last %},{% endif %}
                {% endfor %}
                ], dim=-1){% if not loop.last %},{% endif %}
                {% endif %}
            {% endfor %}
            {% if cols == 1 %}
    )
            {% else %}
    ], dim=-1)
            {% endif %}
        {% elif not (is_symbolic(type) or T.__name__ == "float") %}
            {# Neither a Matrix nor a scalar: reject the output #}
            {{ raise('PyTorch backend does not support Sequence outputs, got output "{}" of type `{}`'.format(name, type)) }}
        {% else %}
            {# Scalar output: its value is the second element of its first (only used) term #}
    _{{name}} = {{ terms[0][1] }}
        {% endif %}
    {% endfor %}

    {# Every sparse output is rejected #}
    {% for name, type, terms in spec.print_code_results.sparse_terms %}
        {{ raise("PyTorch backend does not support sparse outputs: " + name) }}
    {% endfor %}

    return
    {%- for name, type in spec.outputs.items() %}
        {% set T = typing_util.get_type(type) %}
        {% if issubclass(T, Matrix) or is_symbolic(type) or T.__name__ == "float" %}
 _{{name}}
        {%- else -%}
            {{ raise('PyTorch backend only supports sf.Scalar and sf.Matrix outputs, got output "{}" of type `{}`'.format(name, T)) }}
        {%- endif %}
        {%- if not loop.last %}, {% endif %}
    {%- endfor -%}
{% endmacro %}
